3  Experimental Results

This is a supplementary appendix to the research paper "Endogenous Macrodynamics in Algorithmic Recourse". It contains all of the experimental results, including those not highlighted in the actual paper. It also links to additional information about the proposed mitigation strategies.

4 Synthetic Data

This notebook was used to run the experiments for the synthetic datasets and can be used to reproduce the results in the paper. In the following we first run the experiments and then generate visualizations and tables.

4.1 Experiment

# Experiment setup for the synthetic data: classifier types, counterfactual generators,
# data sets and the resulting experiment objects.

# Symbols naming the classifier types handed to `set_up_experiments`.
models = [
    :LogisticRegression, 
    :FluxModel, 
    :FluxEnsemble,
]
# One optimiser instance is shared by all generators that accept an `opt` keyword.
opt = Flux.Descent(0.01) 
# Counterfactual generators, keyed by the symbol used to identify them later on.
generators = Dict(
    :Greedy=>GreedyGenerator(), 
    :Generic=>GenericGenerator(opt = opt),
    :REVISE=>REVISEGenerator(opt = opt),
    :DICE=>DiCEGenerator(opt = opt),
)
# `max_obs` is passed to the loader; presumably a cap on the number of observations
# per data set -- TODO confirm against `load_synthetic`.
max_obs = 1000
catalogue = load_synthetic(max_obs)
# Names of the data sets to keep.
choices = [
    :linearly_separable, 
    :overlapping, 
    :circles, 
    :moons,
]
# Keep only the catalogue entries whose first element (the name) is in `choices`.
data_sets = filter(p -> p[1] in choices, catalogue)
experiments = set_up_experiments(data_sets,models,generators)
# Plot every model of every experiment on its test data before the experiments are run
# and save the combined figure. Misclassified test points are overlaid in red.
# `model_evaluation` is needed inside the loop, so it is imported in this cell.
using AlgorithmicRecourseDynamics.Models: model_evaluation
plts = []
for (exp_name, exp_) in experiments
    for (M_name, M) in exp_.models
        score = round(model_evaluation(M, exp_.test_data),digits=2)
        plt = plot(M, exp_.test_data, title="$exp_name;\n $M_name ($score)")
        # Errors: columns whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.test_data.X)) .!= exp_.test_data.y))
        x_wrongly_labelled = exp_.test_data.X[:,ids]
        scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
        # Append in place rather than rebuilding the vector by splatting it each time.
        push!(plts, plt)
    end
end
# Grid with one row per data set and one column per model. `size` is (width, height),
# so the width follows the number of columns and the height the number of rows.
plt = plot(plts..., layout=(length(choices),length(models)),size=(length(models)*300,length(choices)*300))
savefig(plt, joinpath(www_path,"models_test_before.png"))
# Plot every model of every experiment on its training data before the experiments are
# run and save the combined figure. Misclassified training points are overlaid in red.
using AlgorithmicRecourseDynamics.Models: model_evaluation
plts = []
for (exp_name, exp_) in experiments
    for (M_name, M) in exp_.models
        score = round(model_evaluation(M, exp_.train_data),digits=2)
        plt = plot(M, exp_.train_data, title="$exp_name;\n $M_name ($score)")
        # Errors: columns whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.train_data.X)) .!= exp_.train_data.y))
        x_wrongly_labelled = exp_.train_data.X[:,ids]
        scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
        # Append in place rather than rebuilding the vector by splatting it each time.
        push!(plts, plt)
    end
end
# Grid with one row per data set and one column per model. `size` is (width, height),
# so the width follows the number of columns and the height the number of rows.
plt = plot(plts..., layout=(length(choices),length(models)),size=(length(models)*300,length(choices)*300))
savefig(plt, joinpath(www_path,"models_train_before.png"))
# Run the experiments on the synthetic data and serialise the results.
n_evals = 5
n_rounds = 50
# `n_rounds / n_evals` rounded to an integer (10 with the values above).
evaluate_every = Int(round(n_rounds/n_evals))
n_folds = 5
# NOTE(review): `T` is forwarded to `run_experiments` unchanged; its meaning is not
# visible in this file -- confirm against the package documentation.
T = 100
results = run_experiments(
    experiments;
    save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, T=T
)
# Persist the results so that the plotting cells can reload them from disk.
# NOTE(review): assumes `Serialization` has been loaded in an earlier (setup) cell.
Serialization.serialize(joinpath(output_path,"results.jls"),results)
# Plot the models stored in the recourse systems of one fold on the test data after the
# experiments have run, and save one figure per generator.
plot_dict = Dict(key => Dict() for (key,val) in results)
fold = 1
for (name, res) in results
    exp_ = res.experiment
    # One (initially empty) list of plots per generator for this data set.
    plot_dict[name] = Dict(key => [] for (key,val) in exp_.generators)
    rec_sys = exp_.recourse_systems[fold]
    sys_ids = collect(exp_.system_identifiers)
    # The system count and the fitted model are kept in separate variables; the
    # loop bound is no longer overwritten by a value of a different type.
    n_systems = length(rec_sys)
    for m in 1:n_systems
        model_name, generator_name = sys_ids[m]
        fitted_model = rec_sys[m].model
        score = round(model_evaluation(fitted_model, exp_.test_data),digits=2)
        plt = plot(fitted_model, exp_.test_data, title="$name;\n $model_name ($score)")
        # Errors: columns whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(fitted_model, exp_.test_data.X)) .!= exp_.test_data.y))
        x_wrongly_labelled = exp_.test_data.X[:,ids]
        scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
        push!(plot_dict[name][generator_name], plt)
    end
end
# Regroup: for each generator, concatenate its plots across all data sets.
plot_dict = Dict(key => reduce(vcat, [plots[key] for plots in values(plot_dict)]) for (key, value) in generators)
for (name, plts) in plot_dict
    # Rows are data sets, columns are models; `size` is (width, height).
    plt = plot(plts..., layout=(length(choices),length(models)),size=(length(models)*300,length(choices)*300))
    savefig(plt, joinpath(www_path,"models_test_after_$(name).png"))
end
# Plot the models stored in the recourse systems of one fold on the data held by each
# system after the experiments have run, and save one figure per generator.
using AlgorithmicRecourseDynamics.Models: model_evaluation
plot_dict = Dict(key => Dict() for (key,val) in results)
fold = 1
for (name, res) in results
    exp_ = res.experiment
    # One (initially empty) list of plots per generator for this data set.
    plot_dict[name] = Dict(key => [] for (key,val) in exp_.generators)
    rec_sys = exp_.recourse_systems[fold]
    sys_ids = collect(exp_.system_identifiers)
    # The system count and the fitted model are kept in separate variables; the
    # loop bound is no longer overwritten by a value of a different type.
    n_systems = length(rec_sys)
    for m in 1:n_systems
        model_name, generator_name = sys_ids[m]
        fitted_model = rec_sys[m].model
        # Unlike the test-data figure, the data comes from the recourse system itself.
        data = rec_sys[m].data
        score = round(model_evaluation(fitted_model, data),digits=2)
        plt = plot(fitted_model, data, title="$name;\n $model_name ($score)")
        # Errors: columns whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(fitted_model, data.X)) .!= data.y))
        x_wrongly_labelled = data.X[:,ids]
        scatter!(plt, x_wrongly_labelled[1,:], x_wrongly_labelled[2,:], ms=7.5, color=:red, label="")
        push!(plot_dict[name][generator_name], plt)
    end
end
# Regroup: for each generator, concatenate its plots across all data sets.
plot_dict = Dict(key => reduce(vcat, [plots[key] for plots in values(plot_dict)]) for (key, value) in generators)
for (name, plts) in plot_dict
    # Rows are data sets, columns are models; `size` is (width, height).
    plt = plot(plts..., layout=(length(choices),length(models)),size=(length(models)*300,length(choices)*300))
    savefig(plt, joinpath(www_path,"models_train_after_$(name).png"))
end

4.2 Plots

# Reload the serialised results and write two charts per data set to `www_path`.
results = Serialization.deserialize(joinpath(output_path,"results.jls"));
using Images
# Charts are also kept in memory, keyed by data set name.
line_charts = Dict()
errorbar_charts = Dict()
for (data_name, res) in results
    # Chart produced by `plot(res)`; saved under the "line_chart_" prefix.
    plt = plot(res)
    Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
    line_charts[data_name] = plt
    # Chart for the largest value of `n` in the output; saved under "errorbar_chart_".
    plt = plot(res,maximum(res.output.n))
    Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
    errorbar_charts[data_name] = plt
end

4.2.1 Line Charts

Figure 4.1 shows the evolution of the evaluation metrics over the course of the experiment.

# Display every image in `www_path` whose file name contains "line_chart".
img_files = filter(contains("line_chart"), readdir(www_path))
img_files = map(file -> joinpath(www_path, file), img_files)
foreach(img_files) do img
    display(Images.load(img))
end

(a) Circles

(b) Linearly Separable

(c) Moons

(d) Overlapping

Figure 4.1: Line Charts

4.2.2 Error Bar Charts

Figure 4.2 shows the evaluation metrics at the end of the experiments.

# Display every image in `www_path` whose file name contains "errorbar_chart".
img_files = [joinpath(www_path, file) for file in readdir(www_path) if contains(file, "errorbar_chart")]
foreach(img -> display(Images.load(img)), img_files)

(a) Circles

(b) Linearly Separable

(c) Moons

(d) Overlapping

Figure 4.2: Error Bar Charts

4.3 Bootstrap

# Bootstrap the results; `n_bootstrap` is passed as the second argument and the
# `filename` keyword points at a CSV in `output_path` (presumably the file the
# bootstrap table is written to -- TODO confirm against `run_bootstrap`).
n_bootstrap = 100
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap.csv"))

4.4 Chart in paper

Figure 4.3 shows the chart that went into the paper.

# Build the chart used in the paper for the synthetic experiments: per generator and
# model, the mean (+/- one standard deviation) of selected metrics at the final
# evaluation point of the `:overlapping` data set. The figure itself is drawn in R.
using DataFrames, Statistics
using RCall
df = results[:overlapping].output
# Keep only the rows recorded at the largest value of `n`.
df = df[df.n .== maximum(df.n),:]
gdf = groupby(df, [:generator, :model, :n, :name, :scope])
# Per group: mean, mean + std and mean - std of the metric values.
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
# Restrict to the metrics shown in the chart.
df_plot = df_plot[[name in [:decisiveness, :disagreement, :mmd, :mmd_grid, :model_performance] for name in df_plot.name],:]
# Drop `:mmd` rows with scope `:model`; `:mmd_grid` is labelled "MMD (model)" below.
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.==:model),:]
# Symbol columns become strings so that they can be relabelled and passed to R.
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
# Human-readable labels for metrics, generators and models.
transform!(df_plot, :name => (X -> [x=="decisiveness" ? "Decisiveness" : x for x in X]) => :name)
transform!(df_plot, :name => (X -> [x=="disagreement" ? "Disagreement" : x for x in X]) => :name)
transform!(df_plot, :name => (X -> [x=="mmd" ? "MMD (domain)" : x for x in X]) => :name)
transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X]) => :name)
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
transform!(df_plot, :generator => (X -> [x=="REVISE" ? "Latent" : x for x in X]) => :generator)
transform!(df_plot, :model => (X -> [x=="FluxEnsemble" ? "Deep Ensemble" : x for x in X]) => :model)
transform!(df_plot, :model => (X -> [x=="FluxModel" ? "MLP" : x for x in X]) => :model)
transform!(df_plot, :model => (X -> [x=="LogisticRegression" ? "Linear" : x for x in X]) => :model)

# Facet grid dimensions; they determine the size of the saved figure.
ncol = length(unique(df_plot.model))
nrow = length(unique(df_plot.name))

scale_ = 1.5
R"""
library(data.table)
df_plot <- data.table($df_plot)
name_order <- c(
    "MMD (domain)",
    "MMD (model)",
    "Performance",
    "Disagreement",
    "Decisiveness"
)
df_plot[,name:=factor(name, levels=name_order)]
model_order <- c("Linear", "MLP", "Deep Ensemble")
df_plot[,model:=factor(model, levels=model_order)]
library(ggplot2)
plt <- ggplot(df_plot) +
    geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
    geom_pointrange(aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
    facet_grid(
        rows = vars(name),
        cols =  vars(model), 
        scales = "free_y"
    ) +
    labs(y = "Value") + 
    scale_fill_discrete(name="Generator:") +
    scale_colour_discrete(name="Generator:") +
    theme(
        axis.title.x=element_blank(),
        axis.text.x=element_blank(),
        axis.ticks.x=element_blank(),
        legend.position="bottom"
    )
temp_path <- file.path(tempdir(), "plot.png")
ggsave(temp_path, width=$ncol * $scale_,height=$nrow * $scale_ * 0.75) 
"""

# Copy the figure from R's temporary directory into `www_path` and show it.
img = Images.load(rcopy(R"temp_path"))
Images.save(joinpath(www_path,"paper_synthetic_results.png"), img)
Images.load(joinpath(www_path,"paper_synthetic_results.png"))

Figure 4.3: Chart in paper

#| echo: false

# Call `generate_artifacts` on the results directory and on the figure directory.
# NOTE(review): the helper is defined outside this file; presumably it packages the
# directory contents for distribution -- confirm.
generate_artifacts(output_path)
generate_artifacts(www_path)

4.5 Real-World Data

# Experiment setup for the real-world data: classifier types, counterfactual
# generators and data sets.

# Symbols naming the classifier types handed to `set_up_experiments`.
models = [
    :LogisticRegression, 
    :FluxModel, 
    :FluxEnsemble
]
# One optimiser instance is shared by all generators that accept an `opt` keyword.
opt = Flux.Descent(0.01) 
# Counterfactual generators, keyed by the symbol used to identify them later on.
generators = Dict(
    :Greedy=>GreedyGenerator(), 
    :Generic=>GenericGenerator(opt = opt),
    :REVISE=>REVISEGenerator(opt = opt),
    :DICE=>DiCEGenerator(opt = opt),
)
# `max_obs` is passed to the loader; presumably a cap on the number of observations
# per data set -- TODO confirm against `load_real_world`.
max_obs = 5000
data_sets = load_real_world(max_obs)
# Names of the data sets to keep.
choices = [
    :cal_housing, 
    :credit_default, 
    :gmsc, 
]
# Keep only the entries whose first element (the name) is in `choices`.
data_sets = filter(p -> p[1] in choices, data_sets)
using CounterfactualExplanations.DataPreprocessing: unpack
bs = 500
# Unpack `data` into its two components and wrap them in a `Flux.DataLoader`
# whose batch size is read from the global `bs`.
function data_loader(data::CounterfactualData)
    features, labels = unpack(data)
    return Flux.DataLoader((features, labels), batchsize=bs)
end
# Model hyperparameters forwarded to `set_up_experiments` via `model_params`.
model_params = (batch_norm=false,n_hidden=64,n_layers=3,dropout=true,p_dropout=0.1)
# NOTE(review): `pre_train_models=100` is presumably the number of pre-training
# epochs -- confirm against `set_up_experiments`. The custom `data_loader` defined
# above is supplied through the keyword of the same name.
experiments = set_up_experiments(
    data_sets,models,generators; 
    pre_train_models=100, model_params=model_params, 
    data_loader=data_loader
)

4.5.1 Experiment

# Run the experiments on the real-world data and serialise the results.
n_evals = 5
n_rounds = 50
# `n_rounds / n_evals` rounded to an integer (10 with the values above).
evaluate_every = Int(round(n_rounds/n_evals))
n_folds = 5
# NOTE(review): `n_samples` and `T` are forwarded to `run_experiments` unchanged;
# their meaning is not visible in this file -- confirm against the package docs.
n_samples = 10000
T = 100
# Settings forwarded as `generative_model_params` (presumably for the generative
# model used by the latent-space generator -- TODO confirm).
generative_model_params = (epochs=250, latent_dim=8)
results = run_experiments(
    experiments;
    save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, T=T, n_samples=n_samples,
    generative_model_params=generative_model_params
)
# NOTE(review): same file name ("results.jls") as the synthetic section; this only
# avoids a clash if `output_path` differs between the two notebooks -- verify.
Serialization.serialize(joinpath(output_path,"results.jls"),results)

4.5.2 Plots

# Reload the serialised results and write two charts per data set to `www_path`.
results = Serialization.deserialize(joinpath(output_path,"results.jls"))
using Images
# Charts are also kept in memory, keyed by data set name.
line_charts = Dict()
errorbar_charts = Dict()
for (data_name, res) in results
    # Chart produced by `plot(res)`; saved under the "line_chart_" prefix.
    plt = plot(res)
    Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
    line_charts[data_name] = plt
    # Chart for the largest value of `n` in the output; saved under "errorbar_chart_".
    plt = plot(res,maximum(res.output.n))
    Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
    errorbar_charts[data_name] = plt
end

4.5.2.1 Line Charts

Figure 4.4 shows the evolution of the evaluation metrics over the course of the experiment.

# Display every image in `www_path` whose file name contains "line_chart".
img_files = filter(file -> contains(file, "line_chart"), readdir(www_path))
img_files = [joinpath(www_path, file) for file in img_files]
foreach(img_files) do img
    display(load(img))
end

(a) California Housing

(b) Credit Default

(c) GMSC

Figure 4.4: Line Charts

4.5.2.2 Error Bar Charts

Figure 4.5 shows the evaluation metrics at the end of the experiments.

# Display every image in `www_path` whose file name contains "errorbar_chart".
img_files = filter(contains("errorbar_chart"), readdir(www_path))
img_files = map(file -> joinpath(www_path, file), img_files)
foreach(img -> display(load(img)), img_files)

(a) California Housing

(b) Credit Default

(c) GMSC

Figure 4.5: Error Bar Charts

4.5.3 Bootstrap

# Bootstrap the results; `n_bootstrap` is passed as the second argument and the
# `filename` keyword points at a CSV in `output_path` (presumably the file the
# bootstrap table is written to -- TODO confirm against `run_bootstrap`).
n_bootstrap = 100
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap.csv"))

4.5.4 Chart in paper

Figure 4.6 shows the chart that went into the paper.

# Build the chart used in the paper for the real-world experiments: per generator and
# data set, the mean (+/- one standard deviation) of selected metrics at the final
# evaluation point, for a single model type. The figure itself is drawn in R.
using DataFrames, Statistics
# Only results for this model type are plotted.
model_ = :FluxEnsemble
# Stack the output tables of all data sets, tagging each row with its data set key.
df = DataFrame() 
for (key, val) in results
    df_ = deepcopy(val.output)
    df_.dataset .= key
    df = vcat(df,df_)
end
# Keep only the rows recorded at the largest value of `n` and for the chosen model.
df = df[df.n .== maximum(df.n),:]
df = df[df.model .== model_,:]
# Drop rows whose value is missing, nothing or NaN.
filter!(:value => x -> !any(f -> f(x), (ismissing, isnothing, isnan)), df)
gdf = groupby(df, [:generator, :dataset, :n, :name, :scope])
# Per group: mean, mean + std and mean - std of the metric values.
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
# Restrict to the metrics shown in the chart.
df_plot = df_plot[[name in [:mmd, :model_performance] for name in df_plot.name],:]
# Symbol columns become strings so that they can be relabelled and passed to R.
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
# Distinguish the MMD rows by scope: "mmd" becomes "mmd_<scope>".
df_plot.name .= [r[:name] == "mmd" ? "$(r[:name])_$(r[:scope])" : r[:name] for r in eachrow(df_plot)]
# Human-readable labels for data sets, metrics and generators.
transform!(df_plot, :dataset => (X -> [x=="cal_housing" ? "California Housing" : x for x in X]) => :dataset)
transform!(df_plot, :dataset => (X -> [x=="credit_default" ? "Credit Default" : x for x in X]) => :dataset)
transform!(df_plot, :dataset => (X -> [x=="gmsc" ? "GMSC" : x for x in X]) => :dataset)
transform!(df_plot, :name => (X -> [x=="mmd_domain" ? "MMD (domain)" : x for x in X]) => :name)
transform!(df_plot, :name => (X -> [x=="mmd_model" ? "MMD (model)" : x for x in X]) => :name)
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
transform!(df_plot, :generator => (X -> [x=="REVISE" ? "Latent" : x for x in X]) => :generator)

# Facet grid dimensions; they determine the size of the saved figure.
ncol = length(unique(df_plot.dataset))
nrow = length(unique(df_plot.name))

using RCall
scale_ = 1.75
R"""
library(ggplot2)
plt <- ggplot($df_plot) +
    geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
    geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
    facet_grid(
        rows = vars(name),
        cols =  vars(dataset), 
        scales = "free_y"
    ) +
    labs(y = "Value") + 
    scale_fill_discrete(name="Generator:") +
    scale_colour_discrete(name="Generator:") +
    theme(
        axis.title.x=element_blank(),
        axis.text.x=element_blank(),
        axis.ticks.x=element_blank(),
        legend.position="bottom"
    )
temp_path <- file.path(tempdir(), "plot.png")
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8) 
"""

# Copy the figure from R's temporary directory into `www_path` and show it.
img = Images.load(rcopy(R"temp_path"))
Images.save(joinpath(www_path,"paper_real_world_results.png"), img)
Images.load(joinpath(www_path,"paper_real_world_results.png"))

Figure 4.6: Chart in paper

4.6 Mitigation Strategies

# Classifier types and generators for the mitigation-strategy experiments.
models = [
    :LogisticRegression, 
    :FluxModel, 
    :FluxEnsemble,
]
# One optimiser instance is shared by all generators.
opt = Flux.Descent(0.01) 
# The two `GenericGenerator` entries differ only in `decision_threshold` (0.5 vs 0.9).
generators = Dict(
    :Generic=>GenericGenerator(opt = opt, decision_threshold=0.5),
    :Latent=>REVISEGenerator(opt = opt),
    :Generic_conservative=>GenericGenerator(opt = opt, decision_threshold=0.9),
    :Gravitational=>GravitationalGenerator(opt = opt),
    :ClapROAR=>ClapROARGenerator(opt = opt)
)

4.6.1 Synthetic

# Set up and run the mitigation-strategy experiments on the synthetic data, then
# serialise the results.
# `max_obs` is passed to the loader; presumably a cap on the number of observations
# per data set -- TODO confirm against `load_synthetic`.
max_obs = 1000
catalogue = load_synthetic(max_obs)
# Names of the data sets to keep.
choices = [
    :linearly_separable, 
    :overlapping, 
    :circles, 
    :moons,
]
# Keep only the catalogue entries whose first element (the name) is in `choices`.
data_sets = filter(p -> p[1] in choices, catalogue)
experiments = set_up_experiments(data_sets,models,generators)
n_evals = 5
n_rounds = 50
# `n_rounds / n_evals` rounded to an integer (10 with the values above).
evaluate_every = Int(round(n_rounds/n_evals))
n_folds = 5
# NOTE(review): `T` is forwarded to `run_experiments` unchanged; its meaning is not
# visible in this file -- confirm against the package documentation.
T = 100
using Serialization
results = run_experiments(
    experiments;
    save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, T=T
)
# Persist the results so that the plotting cells can reload them from disk.
Serialization.serialize(joinpath(output_path,"results_synthetic.jls"),results)

4.6.2 Plots

# Reload the serialised results and write two charts per data set to `www_path`.
using Serialization
results = Serialization.deserialize(joinpath(output_path,"results_synthetic.jls"))
using Images
# Charts are also kept in memory, keyed by data set name.
line_charts = Dict()
errorbar_charts = Dict()
for (data_name, res) in results
    # Chart produced by `plot(res)`; saved under the "line_chart_" prefix.
    plt = plot(res)
    Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
    line_charts[data_name] = plt
    # Chart for the largest value of `n` in the output; saved under "errorbar_chart_".
    plt = plot(res,maximum(res.output.n))
    Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
    errorbar_charts[data_name] = plt
end

4.6.2.1 Line Charts

Figure 4.7 shows the evolution of the evaluation metrics over the course of the experiment.

# Display the line charts of the synthetic data sets only.
choices = [
    :linearly_separable, 
    :overlapping, 
    :circles, 
    :moons,
]
# A file is shown when its name contains "line_chart", does not contain "latent" and
# mentions at least one of `choices`. `any` is used instead of converting a sum of
# matches to `Bool`, which would throw for a name matching more than one choice.
img_files = filter(readdir(www_path)) do file
    contains(file, "line_chart") && !contains(file, "latent") &&
        any(choice -> contains(file, string(choice)), choices)
end
img_files = joinpath.(www_path,img_files)
for img in img_files
    display(load(img))
end

(a) Circles

(b) Linearly Separable

(c) Moons

(d) Overlapping

Figure 4.7: Line Charts

4.6.2.2 Error Bar Charts

Figure 4.8 shows the evaluation metrics at the end of the experiments.

# Display the error bar charts of the synthetic data sets only.
choices = [
    :linearly_separable, 
    :overlapping, 
    :circles, 
    :moons,
]
# A file is shown when its name contains "errorbar_chart", does not contain "latent"
# and mentions at least one of `choices`. `any` is used instead of converting a sum of
# matches to `Bool`, which would throw for a name matching more than one choice.
img_files = filter(readdir(www_path)) do file
    contains(file, "errorbar_chart") && !contains(file, "latent") &&
        any(choice -> contains(file, string(choice)), choices)
end
img_files = joinpath.(www_path,img_files)
for img in img_files
    display(load(img))
end

(a) Circles

(b) Linearly Separable

(c) Moons

(d) Overlapping

Figure 4.8: Error Bar Charts

4.6.3 Bootstrap

# Bootstrap the results; `n_bootstrap` is passed as the second argument and the
# `filename` keyword points at a CSV in `output_path` (presumably the file the
# bootstrap table is written to -- TODO confirm against `run_bootstrap`).
n_bootstrap = 100
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_synthetic.csv"))

4.6.4 Chart in paper

Figure 4.9 shows the chart that went into the paper.

# Build the chart used in the paper for the mitigation strategies on synthetic data:
# per generator and model, the mean (+/- one standard deviation) of selected metrics
# at the final evaluation point of the `:overlapping` data set. Drawn in R.
using DataFrames, Statistics
df = results[:overlapping].output
# Keep only the rows recorded at the largest value of `n`.
df = df[df.n .== maximum(df.n),:]
gdf = groupby(df, [:generator, :model, :n, :name, :scope])
# Per group: mean, mean + std and mean - std of the metric values.
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
# Restrict to the metrics shown in the chart.
df_plot = df_plot[[name in [:mmd, :mmd_grid, :model_performance] for name in df_plot.name],:]
# Drop `:mmd` rows with scope `:model`; `:mmd_grid` is labelled "MMD (model)" below.
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.==:model),:]
# Symbol columns become strings so that they can be relabelled and passed to R.
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
# Human-readable labels for metrics, generators and models.
transform!(df_plot, :name => (X -> [x=="mmd" ? "MMD (domain)" : x for x in X]) => :name)
transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X]) => :name)
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
transform!(df_plot, :generator => (X -> [x=="Generic" ? "Generic (γ=0.5)" : x for x in X]) => :generator)
transform!(df_plot, :generator => (X -> [x=="Generic_conservative" ? "Generic (γ=0.9)" : x for x in X]) => :generator)
transform!(df_plot, :model => (X -> [x=="FluxEnsemble" ? "Deep Ensemble" : x for x in X]) => :model)
transform!(df_plot, :model => (X -> [x=="FluxModel" ? "MLP" : x for x in X]) => :model)
transform!(df_plot, :model => (X -> [x=="LogisticRegression" ? "Linear" : x for x in X]) => :model)

# Facet grid dimensions; they determine the size of the saved figure.
ncol = length(unique(df_plot.model))
nrow = length(unique(df_plot.name))

using RCall
scale_ = 2.0
# The plot is built from the R-side `df_plot`, whose `model` column has been turned
# into an ordered factor; plotting the interpolated Julia table would drop that order.
R"""
library(data.table)
df_plot <- data.table($df_plot)
model_order <- c("Linear", "MLP", "Deep Ensemble")
df_plot[,model:=factor(model, levels=model_order)]
library(ggplot2)
plt <- ggplot(df_plot) +
    geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
    geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
    facet_grid(
        rows = vars(name),
        cols =  vars(model), 
        scales = "free_y"
    ) +
    labs(y = "Value") + 
    scale_fill_discrete(name="Generator:") +
    scale_colour_discrete(name="Generator:") +
    theme(
        axis.title.x=element_blank(),
        axis.text.x=element_blank(),
        axis.ticks.x=element_blank(),
        legend.position="bottom"
    ) +
    guides(fill=guide_legend(ncol=3))
temp_path <- file.path(tempdir(), "plot.png")
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8) 
"""

# Copy the figure from R's temporary directory into `www_path` and show it.
img = Images.load(rcopy(R"temp_path"))
Images.save(joinpath(www_path,"paper_synthetic_results.png"), img)
Images.load(joinpath(www_path,"paper_synthetic_results.png"))

Figure 4.9: Chart in paper

4.6.6 Plots

4.6.6.1 Line Charts

Figure 4.10 shows the evolution of the evaluation metrics over the course of the experiment.

# Display every image in `www_path` whose file name contains both "line_chart"
# and "latent".
img_files = filter(readdir(www_path)) do file
    contains(file, "line_chart") && contains(file, "latent")
end
img_files = map(file -> joinpath(www_path, file), img_files)
foreach(img -> display(load(img)), img_files)

(a) Circles

(b) Linearly Separable

(c) Moons

(d) Overlapping

Figure 4.10: Line Charts

4.6.6.2 Error Bar Charts

Figure 4.11 shows the evaluation metrics at the end of the experiments.

# Display every image in `www_path` whose file name contains both "errorbar_chart"
# and "latent".
img_files = [
    joinpath(www_path, file) for file in readdir(www_path)
    if contains(file, "errorbar_chart") && contains(file, "latent")
]
foreach(img_files) do img
    display(load(img))
end

(a) Circles

(b) Linearly Separable

(c) Moons

(d) Overlapping

Figure 4.11: Error Bar Charts

4.6.7 Bootstrap

# Bootstrap the results; `n_bootstrap` is passed as the second argument and the
# `filename` keyword points at a CSV in `output_path` (presumably the file the
# bootstrap table is written to -- TODO confirm against `run_bootstrap`).
# NOTE(review): no cell visible in this section assigns `results`, so this uses
# whatever `results` currently holds -- verify that this is the intended input.
n_bootstrap = 100
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_latent.csv"))

4.6.8 Chart in paper

Figure 4.12 shows the chart that went into the paper.

# Build the chart used in the paper for the latent-space mitigation results: per
# generator and model, the mean (+/- one standard deviation) of selected metrics at
# the final evaluation point of the `:overlapping` data set. Drawn in R.
using DataFrames, Statistics
df = results[:overlapping].output
# Keep only the rows recorded at the largest value of `n`.
df = df[df.n .== maximum(df.n),:]
gdf = groupby(df, [:generator, :model, :n, :name, :scope])
# Per group: mean, mean + std and mean - std of the metric values.
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
# Restrict to the metrics shown in the chart.
df_plot = df_plot[[name in [:mmd, :mmd_grid, :model_performance] for name in df_plot.name],:]
# Drop `:mmd` rows with scope `:model`; `:mmd_grid` is labelled "MMD (model)" below.
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.==:model),:]
# Symbol columns become strings so that they can be relabelled and passed to R.
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
# Human-readable labels for metrics, generators and models.
transform!(df_plot, :name => (X -> [x=="mmd" ? "MMD (domain)" : x for x in X]) => :name)
transform!(df_plot, :name => (X -> [x=="mmd_grid" ? "MMD (model)" : x for x in X]) => :name)
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
transform!(df_plot, :generator => (X -> [x=="Latent" ? "Latent (γ=0.5)" : x for x in X]) => :generator)
transform!(df_plot, :generator => (X -> [x=="Latent_conservative" ? "Latent (γ=0.9)" : x for x in X]) => :generator)
transform!(df_plot, :model => (X -> [x=="FluxEnsemble" ? "Deep Ensemble" : x for x in X]) => :model)
transform!(df_plot, :model => (X -> [x=="FluxModel" ? "MLP" : x for x in X]) => :model)
transform!(df_plot, :model => (X -> [x=="LogisticRegression" ? "Linear" : x for x in X]) => :model)

# Facet grid dimensions; they determine the size of the saved figure.
ncol = length(unique(df_plot.model))
nrow = length(unique(df_plot.name))

using RCall
scale_ = 1.9
# The plot is built from the R-side `df_plot`, whose `model` column has been turned
# into an ordered factor; plotting the interpolated Julia table would drop that order.
R"""
library(data.table)
df_plot <- data.table($df_plot)
model_order <- c("Linear", "MLP", "Deep Ensemble")
df_plot[,model:=factor(model, levels=model_order)]
library(ggplot2)
plt <- ggplot(df_plot) +
    geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
    geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
    facet_grid(
        rows = vars(name),
        cols =  vars(model), 
        scales = "free_y"
    ) +
    labs(y = "Value") + 
    scale_fill_discrete(name="Generator:") +
    scale_colour_discrete(name="Generator:") +
    theme(
        axis.title.x=element_blank(),
        axis.text.x=element_blank(),
        axis.ticks.x=element_blank(),
        legend.position="bottom"
    ) +
    guides(fill=guide_legend(ncol=4))
temp_path <- file.path(tempdir(), "plot.png")
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.8) 
"""

# Copy the figure from R's temporary directory into `www_path` and show it.
img = Images.load(rcopy(R"temp_path"))
Images.save(joinpath(www_path,"paper_synthetic_latent_results.png"), img)
Images.load(joinpath(www_path,"paper_synthetic_latent_results.png"))

Figure 4.12: Chart in paper

4.6.9 Real World

# Setup for the mitigation-strategy experiments on real-world data: classifier
# types, generators and data sets.
models = [
    :LogisticRegression, 
    :FluxModel, 
    :FluxEnsemble,
]
# One optimiser instance is shared by all generators.
opt = Flux.Descent(0.01) 
# The two `GenericGenerator` entries differ only in `decision_threshold` (0.5 vs 0.9).
generators = Dict(
    :Generic=>GenericGenerator(opt = opt, decision_threshold=0.5),
    :Latent=>REVISEGenerator(opt = opt),
    :Generic_conservative=>GenericGenerator(opt = opt, decision_threshold=0.9),
    :Gravitational=>GravitationalGenerator(opt = opt),
    :ClapROAR=>ClapROARGenerator(opt = opt)
)
# `max_obs` is passed to the loader; presumably a cap on the number of observations
# per data set -- TODO confirm against `load_real_world`.
max_obs = 2500
# NOTE(review): `data_path` is not used by any code visible in this section.
data_path = data_dir("real_world")
data_sets = load_real_world(max_obs)
# Names of the data sets to keep.
choices = [
    :cal_housing, 
    :credit_default, 
    :gmsc, 
]
# Keep only the entries whose first element (the name) is in `choices`.
data_sets = filter(p -> p[1] in choices, data_sets)
using CounterfactualExplanations.DataPreprocessing: unpack
bs = 500
# Build a `Flux.DataLoader` over the two components obtained by unpacking `data`;
# the batch size comes from the global `bs`.
function data_loader(data::CounterfactualData)
    inputs, targets = unpack(data)
    return Flux.DataLoader((inputs, targets), batchsize=bs)
end
# Set up and run the mitigation-strategy experiments on the real-world data, then
# serialise the results.
# Model hyperparameters forwarded to `set_up_experiments` via `model_params`.
model_params = (batch_norm=false,n_hidden=64,n_layers=3,dropout=true,p_dropout=0.1)
# NOTE(review): `pre_train_models=100` is presumably the number of pre-training
# epochs -- confirm against `set_up_experiments`.
experiments = set_up_experiments(
    data_sets,models,generators; 
    pre_train_models=100, model_params=model_params, 
    data_loader=data_loader
)
n_evals = 5
n_rounds = 50
# `n_rounds / n_evals` rounded to an integer (10 with the values above).
evaluate_every = Int(round(n_rounds/n_evals))
n_folds = 5
# NOTE(review): `n_samples` and `T` are forwarded to `run_experiments` unchanged;
# their meaning is not visible in this file -- confirm against the package docs.
n_samples = 10000
T = 100
# Settings forwarded as `generative_model_params` (presumably for the generative
# model used by the latent-space generator -- TODO confirm).
generative_model_params = (epochs=250, latent_dim=8)
results = run_experiments(
    experiments;
    save_path=output_path,evaluate_every=evaluate_every,n_rounds=n_rounds, n_folds=n_folds, T=T, n_samples=n_samples,
    generative_model_params=generative_model_params
)
# Persist the results so that the plotting cells can reload them from disk.
Serialization.serialize(joinpath(output_path,"results_real_world.jls"),results)
# Reload the serialised results and write two charts per data set to `www_path`.
using Serialization
results = Serialization.deserialize(joinpath(output_path,"results_real_world.jls"))
using Images
# Charts are also kept in memory, keyed by data set name.
line_charts = Dict()
errorbar_charts = Dict()
for (data_name, res) in results
    # Chart produced by `plot(res)`; saved under the "line_chart_" prefix.
    plt = plot(res)
    Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
    line_charts[data_name] = plt
    # Chart for the largest value of `n` in the output; saved under "errorbar_chart_".
    plt = plot(res,maximum(res.output.n))
    Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
    errorbar_charts[data_name] = plt
end

4.6.9.1 Line Charts

Figure 4.13 shows the evolution of the evaluation metrics over the course of the experiment.

# Display the line charts of the real-world data sets only.
choices = [
    :cal_housing, 
    :credit_default, 
    :gmsc, 
]
# A file is shown when its name contains "line_chart" and mentions at least one of
# `choices`. `any` is used instead of converting a sum of matches to `Bool`, which
# would throw for a name matching more than one choice.
img_files = filter(readdir(www_path)) do file
    contains(file, "line_chart") &&
        any(choice -> contains(file, string(choice)), choices)
end
img_files = joinpath.(www_path,img_files)
for img in img_files
    display(load(img))
end

(a) California Housing

(b) Credit Default

(c) GMSC

Figure 4.13: Line Charts

4.6.9.2 Error Bar Charts

Figure 4.14 shows the evaluation metrics at the end of the experiments.

# Display the error bar charts of the real-world data sets only.
choices = [
    :cal_housing, 
    :credit_default, 
    :gmsc, 
]
# A file is shown when its name contains "errorbar_chart" and mentions at least one
# of `choices`. `any` is used instead of converting a sum of matches to `Bool`, which
# would throw for a name matching more than one choice.
img_files = filter(readdir(www_path)) do file
    contains(file, "errorbar_chart") &&
        any(choice -> contains(file, string(choice)), choices)
end
img_files = joinpath.(www_path,img_files)
for img in img_files
    display(load(img))
end

(a) California Housing

(b) Credit Default

(c) GMSC

Figure 4.14: Error Bar Charts

4.6.9.3 Bootstrap

# Bootstrap the results; `n_bootstrap` is passed as the second argument and the
# `filename` keyword points at a CSV in `output_path` (presumably the file the
# bootstrap table is written to -- TODO confirm against `run_bootstrap`).
n_bootstrap = 100
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path,"bootstrap_real_world.csv"))

4.6.9.4 Chart in paper

Figure 4.15 shows the chart that went into the paper.

# Build the chart used in the paper for the mitigation strategies on real-world data:
# per generator and data set, the mean (+/- one standard deviation) of selected
# metrics at the final evaluation point, for a single model type. Drawn in R.
using DataFrames, Statistics
# Only results for this model type are plotted.
model_ = :FluxModel
# Stack the output tables of all data sets, tagging each row with its data set key.
df = DataFrame() 
for (key, val) in results
    df_ = deepcopy(val.output)
    df_.dataset .= key
    df = vcat(df,df_)
end
# Keep only the rows recorded at the largest value of `n` and for the chosen model.
df = df[df.n .== maximum(df.n),:]
df = df[df.model .== model_,:]
# Drop rows whose value is missing, nothing or NaN.
filter!(:value => x -> !any(f -> f(x), (ismissing, isnothing, isnan)), df)
gdf = groupby(df, [:generator, :dataset, :n, :name, :scope])
# Per group: mean, mean + std and mean - std of the metric values.
df_plot = combine(gdf, :value => (x -> [(mean(x),mean(x)+std(x),mean(x)-std(x))]) => [:mean, :ymax, :ymin])
# Restrict to the metrics shown in the chart.
df_plot = df_plot[[name in [:mmd, :model_performance] for name in df_plot.name],:]
# Of the `:mmd` rows, keep only those with scope `:model`.
df_plot = df_plot[.!(df_plot.name.==:mmd .&& df_plot.scope.!=:model),:]
# Symbol columns become strings so that they can be relabelled and passed to R.
df_plot = mapcols(x -> typeof(x) == Vector{Symbol} ? string.(x) : x, df_plot)
# Human-readable labels for data sets, metrics and generators.
transform!(df_plot, :dataset => (X -> [x=="cal_housing" ? "California Housing" : x for x in X]) => :dataset)
transform!(df_plot, :dataset => (X -> [x=="credit_default" ? "Credit Default" : x for x in X]) => :dataset)
transform!(df_plot, :dataset => (X -> [x=="gmsc" ? "GMSC" : x for x in X]) => :dataset)
transform!(df_plot, :name => (X -> [x=="mmd" ? "MMD (model)" : x for x in X]) => :name)
transform!(df_plot, :name => (X -> [x=="model_performance" ? "Performance" : x for x in X]) => :name)
transform!(df_plot, :generator => (X -> [x=="Generic" ? "Generic (γ=0.5)" : x for x in X]) => :generator)
transform!(df_plot, :generator => (X -> [x=="Generic_conservative" ? "Generic (γ=0.9)" : x for x in X]) => :generator)

# Facet grid dimensions; they determine the size of the saved figure.
ncol = length(unique(df_plot.dataset))
nrow = length(unique(df_plot.name))

using RCall
scale_ = 2.0
R"""
library(ggplot2)
plt <- ggplot($df_plot) +
    geom_bar(aes(x=n, y=mean, fill=generator), stat="identity", alpha=0.5, position="dodge") +
    geom_pointrange( aes(x=n, y=mean, ymin=ymin, ymax=ymax, colour=generator), alpha=0.9, position=position_dodge(width=0.9), size=0.5) +
    facet_grid(
        rows = vars(name),
        cols =  vars(dataset), 
        scales = "free_y"
    ) +
    labs(y = "Value") + 
    scale_fill_discrete(name="Generator:") +
    scale_colour_discrete(name="Generator:") +
    theme(
        axis.title.x=element_blank(),
        axis.text.x=element_blank(),
        axis.ticks.x=element_blank(),
        legend.position="bottom"
    ) +
    guides(fill=guide_legend(ncol=3))
temp_path <- file.path(tempdir(), "plot.png")
ggsave(temp_path,width=$ncol * $scale_,height=$nrow * $scale_ * 0.85) 
"""

# Copy the figure from R's temporary directory into `www_path` and show it.
img = Images.load(rcopy(R"temp_path"))
Images.save(joinpath(www_path,"paper_real_world_results.png"), img)
Images.load(joinpath(www_path,"paper_real_world_results.png"))

Figure 4.15: Chart in paper